import os
import pandas as pd
import numpy as np
import pickle
import json

try:
    import datasets  # HuggingFace datasets (optional dependency)
    _huggingface_available = True
except ImportError:
    _huggingface_available = False


def load_dataset(path: str, format: str = None, **kwargs):
    """
    Load a dataset from disk (or via HuggingFace ``datasets``) in various common formats.

    Args:
        path (str): Path to the dataset file or directory. For the 'huggingface'
            format this value is handed to ``datasets.load_dataset`` as the dataset
            name, so it does not have to exist on the local filesystem.
        format (str): Explicit format hint, case-insensitive. Supported values:
            'csv', 'tsv', 'xlsx', 'json', 'parquet', 'npy', 'npz', 'pkl'/'pickle',
            'huggingface'. If None or empty, the format is inferred from the file
            extension of ``path``.
        **kwargs: Extra arguments passed to the underlying loader (e.g., sep, header,
            encoding for CSV). For 'tsv' the separator defaults to a tab but can be
            overridden with ``sep``. For 'json' only ``encoding`` is used
            (default 'utf-8'); other keyword arguments are ignored.

    Returns:
        Any: The loaded dataset. Tabular formats (csv, tsv, xlsx, parquet) return a
        list of row lists; 'json' returns the decoded JSON value; 'npy' returns what
        ``np.load`` returns; 'npz' returns the array stored under the key "data" if
        present, otherwise a dict mapping every key to its array; 'pkl'/'pickle'
        returns the unpickled object; 'huggingface' returns whatever
        ``datasets.load_dataset`` returns.

    Raises:
        FileNotFoundError: If ``path`` does not exist (not checked for 'huggingface').
        ImportError: If 'huggingface' is requested but ``datasets`` is not installed.
        ValueError: If the format is not supported.

    Warning:
        The 'npy', 'npz' and 'pkl'/'pickle' loaders unpickle data, which can execute
        arbitrary code. Only load such files from trusted sources.
    """
    if format:
        fmt = format.lower()
    else:
        fmt = os.path.splitext(path)[-1].lower().lstrip(".")

    # HuggingFace is dispatched before the existence check: here `path` is a
    # dataset name, which usually is not a local file or directory.
    if fmt == "huggingface":
        if not _huggingface_available:
            raise ImportError("Please install the 'datasets' package to use HuggingFace datasets.")
        return datasets.load_dataset(path, **kwargs)

    if not os.path.exists(path):
        raise FileNotFoundError(f"Dataset not found: {path}")

    if fmt in {"csv", "tsv", "xlsx", "parquet"}:
        if fmt == "csv":
            df = pd.read_csv(path, **kwargs)
        elif fmt == "tsv":
            # setdefault (rather than a hard-coded sep=...) so a caller-supplied
            # `sep` does not collide with ours and raise a duplicate-keyword TypeError.
            kwargs.setdefault("sep", "\t")
            df = pd.read_csv(path, **kwargs)
        elif fmt == "xlsx":
            df = pd.read_excel(path, **kwargs)
        else:
            df = pd.read_parquet(path, **kwargs)
        return df.values.tolist()

    if fmt == "json":
        with open(path, "r", encoding=kwargs.get("encoding", "utf-8")) as f:
            return json.load(f)

    if fmt == "npy":
        # SECURITY: allow_pickle=True can execute arbitrary code from the file.
        return np.load(path, allow_pickle=True)

    if fmt == "npz":
        # SECURITY: allow_pickle=True can execute arbitrary code from the file.
        # The context manager closes the underlying archive handle; the arrays
        # are fully read into memory before the archive is closed.
        with np.load(path, allow_pickle=True) as archive:
            # Try default key "data", otherwise return every array as a dict.
            if "data" in archive.files:
                return archive["data"]
            return {key: archive[key] for key in archive.files}

    if fmt in {"pkl", "pickle"}:
        # SECURITY: unpickling untrusted data can execute arbitrary code.
        with open(path, "rb") as f:
            return pickle.load(f)

    raise ValueError(f"Unsupported dataset format: '{fmt}'")


def infer_format_from_extension(path: str) -> str:
    """
    Derive a dataset format string from the extension of a file path.

    Args:
        path (str): File path.

    Returns:
        str: The lower-cased extension without its leading dot, or an empty
        string when the path has no extension.
    """
    _, suffix = os.path.splitext(path)
    lowered = suffix.lower()
    # splitext yields either "" or a string beginning with a single ".".
    return lowered[1:] if lowered.startswith(".") else lowered
